{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: torch==2.2 in /opt/conda/lib/python3.10/site-packages (2.2.0)\n",
      "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (3.13.1)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (4.9.0)\n",
      "Requirement already satisfied: sympy in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (1.12)\n",
      "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.8.5)\n",
      "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2023.12.2)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /opt/conda/lib/python3.10/site-packages (from torch==2.2) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /opt/conda/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch==2.2) (12.3.101)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch==2.2) (2.1.3)\n",
      "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch==2.2) (1.3.0)\n",
      "Collecting transformers==4.36.0\n",
      "  Downloading transformers-4.36.0-py3-none-any.whl.metadata (126 kB)\n",
      "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m126.8/126.8 kB\u001b[0m \u001b[31m6.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (3.13.1)\n",
      "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (0.20.1)\n",
      "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (1.26.2)\n",
      "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (21.3)\n",
      "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (6.0)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (2022.7.9)\n",
      "Requirement already satisfied: requests in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (2.28.1)\n",
      "Requirement already satisfied: tokenizers<0.19,>=0.14 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (0.15.2)\n",
      "Requirement already satisfied: safetensors>=0.3.1 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (0.4.1)\n",
      "Requirement already satisfied: tqdm>=4.27 in /opt/conda/lib/python3.10/site-packages (from transformers==4.36.0) (4.66.1)\n",
      "Requirement already satisfied: fsspec>=2023.5.0 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.19.3->transformers==4.36.0) (2023.12.2)\n",
      "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.10/site-packages (from huggingface-hub<1.0,>=0.19.3->transformers==4.36.0) (4.9.0)\n",
      "Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /opt/conda/lib/python3.10/site-packages (from packaging>=20.0->transformers==4.36.0) (3.0.9)\n",
      "Requirement already satisfied: charset-normalizer<3,>=2 in /opt/conda/lib/python3.10/site-packages (from requests->transformers==4.36.0) (2.1.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests->transformers==4.36.0) (3.3)\n",
      "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.10/site-packages (from requests->transformers==4.36.0) (1.26.10)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests->transformers==4.36.0) (2022.6.15)\n",
      "Downloading transformers-4.36.0-py3-none-any.whl (8.2 MB)\n",
      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.2/8.2 MB\u001b[0m \u001b[31m47.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0mm\n",
      "\u001b[?25hInstalling collected packages: transformers\n",
      "  Attempting uninstall: transformers\n",
      "    Found existing installation: transformers 4.36.2\n",
      "    Uninstalling transformers-4.36.2:\n",
      "      Successfully uninstalled transformers-4.36.2\n",
      "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
      "minimagen 0.0.9 requires attrs==21.4.0, but you have attrs 23.2.0 which is incompatible.\n",
      "minimagen 0.0.9 requires filelock==3.7.1, but you have filelock 3.13.1 which is incompatible.\n",
      "minimagen 0.0.9 requires fsspec==2022.5.0, but you have fsspec 2023.12.2 which is incompatible.\n",
      "minimagen 0.0.9 requires huggingface-hub==0.8.1, but you have huggingface-hub 0.20.1 which is incompatible.\n",
      "minimagen 0.0.9 requires numpy==1.23.1, but you have numpy 1.26.2 which is incompatible.\n",
      "minimagen 0.0.9 requires Pillow==9.2.0, but you have pillow 10.2.0 which is incompatible.\n",
      "minimagen 0.0.9 requires tokenizers==0.12.1, but you have tokenizers 0.15.2 which is incompatible.\n",
      "minimagen 0.0.9 requires torch==1.12.0, but you have torch 2.2.0 which is incompatible.\n",
      "minimagen 0.0.9 requires torchvision==0.13.0, but you have torchvision 0.17.0 which is incompatible.\n",
      "minimagen 0.0.9 requires tqdm==4.64.0, but you have tqdm 4.66.1 which is incompatible.\n",
      "minimagen 0.0.9 requires transformers==4.20.1, but you have transformers 4.36.0 which is incompatible.\n",
      "minimagen 0.0.9 requires typing_extensions==4.3.0, but you have typing-extensions 4.9.0 which is incompatible.\u001b[0m\u001b[31m\n",
      "\u001b[0mSuccessfully installed transformers-4.36.0\n"
     ]
    }
   ],
   "source": [
    "!pip install torch==2.2\n",
    "!pip install transformers==4.36.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py:441: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
      "  _torch_pytree._register_pytree_node(\n",
      "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
      "  _torch_pytree._register_pytree_node(\n",
      "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py:309: UserWarning: torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.\n",
      "  _torch_pytree._register_pytree_node(\n"
     ]
    }
   ],
   "source": [
    "from transformers import GPT2LMHeadModel, GPT2Tokenizer\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cc87b3a5a41b49ea8473376e34f6cea9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "473f61fb367e4fd7934081acbe4842dc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fc9bee52b6cd476bb90d8d45c4848e1e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d3fe8d7e999c4970ba8671e9d1885be6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "10855abc8d97444d81efbe8371a55d01",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1ff1bfb41ed0427595228277c15b295b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "torch.manual_seed(799)\n",
    "tkz = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
    "mdl = GPT2LMHeadModel.from_pretrained('gpt2')\n",
    "ln = 10\n",
    "cue = \"They\"\n",
    "gen = tkz(cue, return_tensors=\"pt\")\n",
    "to_ret = gen[\"input_ids\"][0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "They are not the only ones who are being targeted.\n"
     ]
    }
   ],
   "source": [
    "prv=None\n",
    "for i in range(ln):\n",
    "    outputs = mdl(**gen)\n",
    "    next_token_logits = torch.argmax(outputs.logits[-1, :])\n",
    "    to_ret = torch.cat([to_ret, next_token_logits.unsqueeze(0)])\n",
    "    gen = {\"input_ids\": to_ret}\n",
    "\n",
    "seq = tkz.decode(to_ret)\n",
    "\n",
    "print(seq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "They are not the only ones who are being targeted\n"
     ]
    }
   ],
   "source": [
    "ip_ids = tkz.encode(cue, return_tensors='pt')\n",
    "op_greedy = mdl.generate(ip_ids, max_length=ln, pad_token_id=tkz.eos_token_id)\n",
    "seq = tkz.decode(op_greedy[0], skip_special_tokens=True)\n",
    "print(seq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "They have a lot of\n",
      "They have a lot to\n",
      "They are not the only\n"
     ]
    }
   ],
   "source": [
    "op_beam = mdl.generate(\n",
    "    ip_ids, \n",
    "    max_length=5, \n",
    "    num_beams=3, \n",
    "    num_return_sequences=3,\n",
    "    pad_token_id=tkz.eos_token_id\n",
    ")\n",
    "\n",
    "for op_beam_cur in op_beam:\n",
    "    print(tkz.decode(op_beam_cur, skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "They are the only ones\n",
      "They have been in the\n",
      "They have a lot of\n"
     ]
    }
   ],
   "source": [
    "for i in range(3):\n",
    "    torch.manual_seed(i+10)\n",
    "    op = mdl.generate(\n",
    "        ip_ids, \n",
    "        do_sample=True, \n",
    "        max_length=5, \n",
    "        top_k=2,\n",
    "        pad_token_id=tkz.eos_token_id\n",
    "    )\n",
    "\n",
    "    seq = tkz.decode(op[0], skip_special_tokens=True)\n",
    "    print(seq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "They are not the only\n",
      "They are not the only\n",
      "They are not the only\n"
     ]
    }
   ],
   "source": [
    "for i in range(3):\n",
    "    torch.manual_seed(i+10)\n",
    "    op_greedy = mdl.generate(ip_ids, max_length=5, pad_token_id=tkz.eos_token_id)\n",
    "    seq = tkz.decode(op_greedy[0], skip_special_tokens=True)\n",
    "    print(seq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "They in the jail know\n",
      "They live on a budget\n",
      "They\n",
      "\n",
      "Says\n"
     ]
    }
   ],
   "source": [
    "for i in range(3):\n",
    "    torch.manual_seed(i+10)\n",
    "    op = mdl.generate(\n",
    "        ip_ids, \n",
    "        do_sample=True, \n",
    "        max_length=5, \n",
    "        top_p=0.75, \n",
    "        top_k=0,\n",
    "        pad_token_id=tkz.eos_token_id\n",
    "    )\n",
    "\n",
    "    seq = tkz.decode(op[0], skip_special_tokens=True)\n",
    "    print(seq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (Local)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
